"""Separate actions and triggers [25155145c545].

Revision ID: 25155145c545
Revises: 0.58.2
Create Date: 2024-05-16 11:29:53.341275

"""

from uuid import uuid4

import sqlalchemy as sa
import sqlmodel
from alembic import op

from zenml.utils.time_utils import utc_now

# revision identifiers, used by Alembic.
# ID of this migration.
revision = "25155145c545"
# Revision this migration is applied on top of (see "Revises" above).
down_revision = "0.58.2"
# This revision has no branch labels and no cross-branch dependencies.
branch_labels = None
depends_on = None


def migrate_actions() -> None:
    """Migrate actions from the trigger table.

    For every existing trigger, insert one row into the new ``action`` table.
    The row is built from the action-related columns still stored on the
    trigger (``action``, ``action_flavor``, ``action_subtype``,
    ``auth_window``, ``service_account_id``). The trigger's ``action_id``
    column is then pointed at that new row.

    Must run after the ``action`` table and the nullable ``trigger.action_id``
    column have been created, and before the old action columns are dropped
    from ``trigger``.
    """
    conn = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(only=("trigger", "action"), bind=conn)
    trigger_table = sa.Table("trigger", meta)
    action_table = sa.Table("action", meta)

    triggers = conn.execute(
        sa.select(
            trigger_table.c.id,
            trigger_table.c.name,
            trigger_table.c.user_id,
            trigger_table.c.workspace_id,
            trigger_table.c.service_account_id,
            trigger_table.c.auth_window,
            trigger_table.c.action,
            trigger_table.c.action_subtype,
            trigger_table.c.action_flavor,
        )
    ).fetchall()

    # One timestamp for the whole batch so all migrated actions share it.
    now = utc_now()

    actions_to_insert = []
    trigger_updates = {}
    for trigger in triggers:
        # Equivalent to str(uuid4()).replace("-", ""): 32 hex characters
        # without dashes.
        action_id = uuid4().hex

        actions_to_insert.append(
            {
                "id": action_id,
                "workspace_id": trigger.workspace_id,
                "user_id": trigger.user_id,
                "created": now,
                "updated": now,
                "name": f"{trigger.name}_action",
                "description": f"Automatically migrated action for trigger {trigger.name}",
                "service_account_id": trigger.service_account_id,
                "auth_window": trigger.auth_window,
                "configuration": trigger.action,
                "flavor": trigger.action_flavor,
                "plugin_subtype": trigger.action_subtype,
            }
        )

        trigger_updates[trigger.id] = action_id

    op.bulk_insert(action_table, actions_to_insert)

    for trigger_id, action_id in trigger_updates.items():
        conn.execute(
            sa.update(trigger_table)
            # Plain equality (SQL `=`). `.is_()` would render SQL `IS`, which
            # MySQL only accepts with TRUE/FALSE/UNKNOWN/NULL operands.
            .where(trigger_table.c.id == trigger_id)
            .values(action_id=action_id)
        )


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Moves the action definition out of the ``trigger`` table into a new,
    separate ``action`` table:

    1. Create the ``action`` table.
    2. Add ``trigger.action_id`` as a nullable column.
    3. Create one ``action`` row per existing trigger from that trigger's
       action columns and link the trigger to it (``migrate_actions``).
    4. Make ``trigger.action_id`` non-nullable, add a nullable ``schedule``
       column, make ``event_source_id`` nullable, rework the foreign keys,
       and drop the action-related columns from ``trigger``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # New table holding the action definitions that used to be stored inline
    # on each trigger row.
    op.create_table(
        "action",
        sa.Column(
            "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False
        ),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column(
            "service_account_id", sqlmodel.sql.sqltypes.GUID(), nullable=False
        ),
        sa.Column("description", sa.TEXT(), nullable=True),
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("auth_window", sa.Integer(), nullable=False),
        sa.Column(
            "flavor", sqlmodel.sql.sqltypes.AutoString(), nullable=False
        ),
        sa.Column(
            "plugin_subtype",
            sqlmodel.sql.sqltypes.AutoString(),
            nullable=False,
        ),
        sa.Column("configuration", sa.LargeBinary(), nullable=False),
        # Deleting the service account deletes the actions that use it.
        sa.ForeignKeyConstraint(
            ["service_account_id"],
            ["user.id"],
            name="fk_action_service_account_id_user",
            ondelete="CASCADE",
        ),
        # Deleting the owning user keeps the action but clears user_id.
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            name="fk_action_user_id_user",
            ondelete="SET NULL",
        ),
        # Deleting the workspace deletes its actions.
        sa.ForeignKeyConstraint(
            ["workspace_id"],
            ["workspace.id"],
            name="fk_action_workspace_id_workspace",
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )

    # Add the action_id column as nullable until we migrate the actions into
    # a separate table
    with op.batch_alter_table("trigger", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("action_id", sqlmodel.sql.sqltypes.GUID(), nullable=True)
        )

    # Data step: create one action per trigger and fill trigger.action_id.
    # It must run while the old action columns still exist on `trigger`.
    migrate_actions()

    with op.batch_alter_table("trigger", schema=None) as batch_op:
        # Make the action_id column non-nullable now that we've inserted the
        # actions and the value is set for each row
        batch_op.alter_column(
            "action_id",
            existing_type=sqlmodel.sql.sqltypes.GUID(),
            existing_nullable=True,
            nullable=False,
        )

        # New nullable binary column; existing triggers get NULL.
        batch_op.add_column(
            sa.Column("schedule", sa.LargeBinary(), nullable=True)
        )
        # A trigger no longer has to reference an event source.
        batch_op.alter_column(
            "event_source_id", existing_type=sa.CHAR(length=32), nullable=True
        )
        # Replace the event source FK; it is recreated below with SET NULL.
        batch_op.drop_constraint(
            "fk_trigger_event_source_id_event_source", type_="foreignkey"
        )
        # Drop this FK before its column (service_account_id) is dropped
        # below; the service account is now stored on the action.
        batch_op.drop_constraint(
            "fk_trigger_service_account_id_user", type_="foreignkey"
        )
        # Deleting an event source now clears event_source_id instead of
        # deleting the trigger row.
        batch_op.create_foreign_key(
            "fk_trigger_event_source_id_event_source",
            "event_source",
            ["event_source_id"],
            ["id"],
            ondelete="SET NULL",
        )
        # Deleting an action deletes the triggers that reference it.
        batch_op.create_foreign_key(
            "fk_trigger_action_id_action",
            "action",
            ["action_id"],
            ["id"],
            ondelete="CASCADE",
        )
        # These fields now live on the `action` table (copied over by
        # migrate_actions above).
        batch_op.drop_column("action_subtype")
        batch_op.drop_column("action_flavor")
        batch_op.drop_column("action")
        batch_op.drop_column("auth_window")
        batch_op.drop_column("service_account_id")

    # ### end Alembic commands ###


def _restore_trigger_actions() -> None:
    """Copy action data from the ``action`` table back onto the triggers.

    Reverse of ``migrate_actions``. For every action row, write its service
    account, auth window, configuration, flavor and subtype into the
    re-added legacy columns of every trigger whose ``action_id`` references
    it. The action's name, description and timestamps have no counterpart
    on the old ``trigger`` table, so they are discarded together with the
    ``action`` table.
    """
    conn = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(only=("trigger", "action"), bind=conn)
    trigger_table = sa.Table("trigger", meta)
    action_table = sa.Table("action", meta)

    actions = conn.execute(
        sa.select(
            action_table.c.id,
            action_table.c.service_account_id,
            action_table.c.auth_window,
            action_table.c.configuration,
            action_table.c.flavor,
            action_table.c.plugin_subtype,
        )
    ).fetchall()

    for action in actions:
        # Several triggers may share one action; update all of them.
        conn.execute(
            sa.update(trigger_table)
            .where(trigger_table.c.action_id == action.id)
            .values(
                service_account_id=action.service_account_id,
                auth_window=action.auth_window,
                action=action.configuration,
                action_flavor=action.flavor,
                action_subtype=action.plugin_subtype,
            )
        )


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Re-adds the action columns to ``trigger``, fills them from each trigger's
    linked ``action`` row, restores the previous foreign keys and
    nullability, and finally drops ``trigger.schedule``,
    ``trigger.action_id`` and the ``action`` table.
    """
    # Re-add the legacy action columns as nullable first. Adding them as
    # NOT NULL without values fails for any existing trigger row (SQLite
    # batch copy), or fills implicit empty values that then break the
    # service account foreign key (MySQL).
    with op.batch_alter_table("trigger", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("service_account_id", sa.CHAR(length=32), nullable=True)
        )
        batch_op.add_column(
            sa.Column("auth_window", sa.INTEGER(), nullable=True)
        )
        batch_op.add_column(sa.Column("action", sa.BLOB(), nullable=True))
        # AutoString instead of a bare VARCHAR: VARCHAR without a length
        # does not compile on MySQL.
        batch_op.add_column(
            sa.Column(
                "action_flavor",
                sqlmodel.sql.sqltypes.AutoString(),
                nullable=True,
            )
        )
        batch_op.add_column(
            sa.Column(
                "action_subtype",
                sqlmodel.sql.sqltypes.AutoString(),
                nullable=True,
            )
        )

    _restore_trigger_actions()

    with op.batch_alter_table("trigger", schema=None) as batch_op:
        # Every trigger references an action (action_id is NOT NULL with a
        # FK to `action`), so every row now has values in these columns.
        batch_op.alter_column(
            "service_account_id",
            existing_type=sa.CHAR(length=32),
            existing_nullable=True,
            nullable=False,
        )
        batch_op.alter_column(
            "auth_window",
            existing_type=sa.INTEGER(),
            existing_nullable=True,
            nullable=False,
        )
        batch_op.alter_column(
            "action",
            existing_type=sa.BLOB(),
            existing_nullable=True,
            nullable=False,
        )
        batch_op.alter_column(
            "action_flavor",
            existing_type=sqlmodel.sql.sqltypes.AutoString(),
            existing_nullable=True,
            nullable=False,
        )
        batch_op.alter_column(
            "action_subtype",
            existing_type=sqlmodel.sql.sqltypes.AutoString(),
            existing_nullable=True,
            nullable=False,
        )
        batch_op.drop_constraint(
            "fk_trigger_action_id_action", type_="foreignkey"
        )
        batch_op.drop_constraint(
            "fk_trigger_event_source_id_event_source", type_="foreignkey"
        )
        batch_op.create_foreign_key(
            "fk_trigger_service_account_id_user",
            "user",
            ["service_account_id"],
            ["id"],
            ondelete="CASCADE",
        )
        batch_op.create_foreign_key(
            "fk_trigger_event_source_id_event_source",
            "event_source",
            ["event_source_id"],
            ["id"],
            ondelete="CASCADE",
        )
        # NOTE(review): triggers without an event source (event_source_id
        # IS NULL, allowed since the upgrade) cannot be represented in the
        # previous schema; this step fails if any such rows exist.
        batch_op.alter_column(
            "event_source_id", existing_type=sa.CHAR(length=32), nullable=False
        )
        batch_op.drop_column("schedule")
        batch_op.drop_column("action_id")

    op.drop_table("action")
